[TRITON_KERNELS] Support sm120 / 121 via sm80 fallback #8484
base: main
Conversation
```python
@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
```
What other properties are incorrect for sm_120?
Not sure what you mean by "other" or "incorrect properties". Without this workaround, the kernel tries to use native mxfp and TMA, assuming that sm120 has the full feature set of sm100. But those are the only things currently breaking gpt-oss on sm120 / 121.
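To see the pitfall concretely, here is a minimal, simplified sketch (not the actual helper body): a plain numeric comparison makes sm120 look like a superset of both sm100 and sm90.

```python
# Simplified illustration only: with a plain numeric comparison, sm120 (arch 120)
# satisfies every capability check below it, so the kernel picks native-MXFP /
# TMA code paths that the current sm120 compiler support can't handle.
def naive_capability_geq(arch: int, major: int, minor: int = 0) -> bool:
    return arch >= major * 10 + minor

assert naive_capability_geq(120, 10)  # sm120 "has" all sm100 features
assert naive_capability_geq(120, 9)   # ... and all sm90 features
```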
I meant in addition to the checks you modified. Do you know which uses of cuda_capability_geq are causing problems?
It's hard to say. I've seen two kinds of errors: one is the use of TMA gather4 / scatter4, and the other is a shape mismatch in dot. cuda_capability_geq is used in many places and the options supported by the kernel are very broad, so I don't know which of them are actually problematic. Indeed, if we want to optimize for sm120 / 121, we need a more fine-grained approach to the capability check rather than falling back to sm80 for everything.
For example, the determination of the weight layout is highly architecture-specific: https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/layout.py#L22-L27. Even if we allowed has_native_mxfp to evaluate to True for sm120, I don't know whether BlackwellMXValueLayout is compatible with the dot shape of MMAv2.
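For context, the linked lines perform a capability-gated selection roughly along these lines (a paraphrased sketch; apart from BlackwellMXValueLayout, the names are placeholders, not the actual API):

```python
# Paraphrased sketch of the capability-gated weight-layout choice; the returned
# names are placeholders except BlackwellMXValueLayout.
def pick_mxfp4_weight_layout(capability_major: int) -> str:
    if capability_major >= 10:
        # Native MX value layout tuned for Blackwell's dot shapes.
        return "BlackwellMXValueLayout"
    # Pre-Blackwell (and, with this PR, sm120): a plain layout that the
    # MMAv2 dot path can consume.
    return "StridedLayout"
```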
```diff
-# hopper w/ mxfp4 doesn't support TMA
-can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] > 9 or bitwidth(w.dtype) != 4)
+# hopper or sm120 w/ mxfp4 doesn't support TMA
+can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] == 10 or bitwidth(w.dtype) != 4)
```
Should we have a separate helper for this logic? I'm pretty sure we will enable TMA on Hopper at some point, so this will break.
Yeah, but this one is a bit different, since it's an ad-hoc check due to a kernel limitation rather than an architectural one. We could add something like target_info.supports_tma(), but that would need to return False for Hopper today, which is a bit odd. So when the kernel supports TMA for Hopper in the future, we would need to update the helper anyway.
As a middle ground, how about something like this?
```python
# hopper or sm120 w/ mxfp4 doesn't support TMA
supports_tma = [10]  # Add 9 when the Hopper impl supports TMA
can_use_tma = can_use_tma and (torch.cuda.get_device_capability()[0] in supports_tma or bitwidth(w.dtype) != 4)
```
This way, when Hopper supports TMA, we can safely update it without breaking sm120. The condition torch.cuda.get_device_capability()[0] >= 9 might not be correct depending on how well sm120 TMA is supported by the kernel.
```python
@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
    target = current_target()
    if target.arch // 10 == 12 and major > 8:
```
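The hunk above is truncated, so here is a hedged sketch of what the modified helper plausibly looks like. Only the lines shown in the diff are confirmed; the branch body and the final comparison are assumptions based on the PR's "sm80 fallback" description.

```python
@triton.constexpr_function
def cuda_capability_geq(major, minor=0):
    target = current_target()
    # Assumed branch body: report sm120 / sm121 as capped at sm80-level
    # features so callers skip the native-MXFP / TMA code paths.
    if target.arch // 10 == 12 and major > 8:
        return False
    # Assumed original behavior: compare the requested capability against
    # the target's (major, minor) pair.
    return (target.arch // 10, target.arch % 10) >= (major, minor)
```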
I understand this is a workaround, but the function name doesn't reflect what it's really doing. sm80 and sm120 still have subtle differences in the instructions.
Is it possible to separate the logic from this function?
The benchmark-related changes look good to me. Thanks for catching these problems!
> sm80 and sm120 still have subtle differences in the instructions
Yes in terms of the architecture, but what really matters is whether those differences are recognized by the compiler or the kernel. Support for sm120 in the compiler is very limited, so from the compiler / kernel perspective, sm80 and sm120 are pretty much the same.
We could introduce another helper to distinguish those kernel / compiler limitations; the Hopper limitation on TMA (#8484 (comment)) is another good example. But cuda_capability_geq is already used in so many places that adding more conditions makes things even more complicated.
The pervasive use of cuda_capability_geq indicates that the kernel treats "higher compute capability" as "more features". As of sm120, this is no longer true. Checking compute capability is also meaningless when the relevant support is not available in the compiler or the kernel. So rather than adding more ad-hoc helpers / checks, we should revisit the use of compute capability as a criterion for feature selection.
I think we need some kind of "Backend" class from which all supported SM variants are derived. We can encode the target-specific feature sets supported by the kernel there and cleanly express idiosyncrasies of the kernel (sketched after the list below), like:
- "SM90" backend does not support TMA with mxfp4 due to a kernel limitation, despite HW support
- "SM120" backend does not support native MXFP or TMA due to compiler limitations, despite HW support
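A rough sketch of what such a Backend hierarchy could look like (all names and attributes here are illustrative; nothing like this exists in triton_kernels today):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Backend:
    """Feature set as supported by the kernels / compiler, not just the HW."""
    name: str
    has_native_mxfp: bool
    supports_tma: bool
    supports_tma_with_mxfp4: bool

# sm90: HW has TMA, but the mxfp4 kernel path doesn't use it yet.
SM90 = Backend("sm90", has_native_mxfp=False, supports_tma=True, supports_tma_with_mxfp4=False)
# sm100: full native MXFP and TMA support.
SM100 = Backend("sm100", has_native_mxfp=True, supports_tma=True, supports_tma_with_mxfp4=True)
# sm120 / sm121: the HW has the features, but compiler support is limited,
# so expose only the sm80-level feature set for now.
SM120 = Backend("sm120", has_native_mxfp=False, supports_tma=False, supports_tma_with_mxfp4=False)

def backend_for(capability_major: int) -> Backend:
    # Anything not listed falls back to an sm80-level feature set.
    table = {9: SM90, 10: SM100, 12: SM120}
    return table.get(capability_major, Backend(f"sm{capability_major}x", False, False, False))
```

Feature checks in the kernels would then query the backend (e.g. backend_for(12).supports_tma_with_mxfp4) instead of comparing compute capabilities directly.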
The main motivation is to support running gpt-oss on DGX Spark (sm121). Until we properly enable mixed-precision MXFP and TMA for sm120 / sm121, we need to fall back to the sm80 compilation path.
All MoE tests pass on RTX6000 with this change.
Related issue: #8335